CHESTNUT CODE IS POETRY

【并行策略】为什么Split和Gather实现序列并行之后要对梯度进行缩放

考虑$\mathbf{x}\in \mathbb{R}^{d_{in}\times B}$,其中$d_{in}$是输入维度,$B$是序列长度,$\mathbf{y}\in \mathbb{R}^{d_{out}\times B}$,其中$d_{out}$是输出维度。有权重$W\in \mathbb{R}^{d_{in}\times d_{out}}$

为了简化计算,我们考虑简单的 FFN,对$\mathbf{x}$乘上权重矩阵得到$\mathbf{y}$

$$ \mathbf{y}= \mathbf{W}^T \mathbf{x}\\ [y]_{i,b} = [\mathbf{w}]_{:, i}^T [\mathbf{x}]_{:,b} $$

假设$\mathbf{y}$的优化目标是$\hat{\mathbf{y}}$。有损失函数:

$$ L = \frac{1}{B d_{out}} \sum_{i=1}^{d_{out}} \sum_{b=1}^B \|\mathbf{y}_{i,b}-\hat{\mathbf{y}}_{i,b}\|^2 $$

为了简化,我们只对一个权重列向量$\mathbf{w}_i \in \mathbb{R}^{d_{in} \times 1}, i=1,2,\dots,d_{out}$考虑。权重列向量的梯度为:

$$ \frac{\partial L}{\partial \mathbf{w}_i} = \sum_{b=1}^B \frac{\partial L}{\partial \mathbf{y}_{i,b}} \frac{\partial \mathbf{y}_{i,b}}{\partial \mathbf{w}_i} = \frac{1}{B d_{out}} \sum_{b=1}^B 2(\mathbf{y}_{i,b}-\hat{\mathbf{y}}_{i,b}) \mathbf{x}_{:,b} $$

假设并行度$G=2$,我们在输入FFN之前对输入$\mathbf{x}$$B$轴上进行了并行,计算损失的时候又将$B$轴上的并行结果 gather 起来,得到$\mathbf{y}$

这样就会导致在每一个rank上,权重列向量的梯度发生改变:

$$ (\frac{\partial L}{\partial \mathbf{w}_i})_{g'} = \frac{1}{B d_{out}} \sum_{b=1}^{B/G} 2(\mathbf{y}_{i,b}-\hat{\mathbf{y}}_{i,b}) \mathbf{x}_{:,b} $$

这样会导致随着并行度越高,梯度会按照$1/G$的比例缩小。因此为了抵消并行度带来的影响,在Gather的反向传播中,除了要将梯度split外,还要加入缩放因子:

$$ (\frac{\partial L}{\partial \mathbf{w}_i})_{g} = G(\frac{\partial L}{\partial \mathbf{w}_i})_{g'} $$

同样Split的反向传播中,梯度会按照$G$的比例放大,也需要加入缩放因子:

$$ (\frac{\partial L}{\partial \mathbf{x}_{:,b}})_{s} = \frac{1}{G} (\frac{\partial L}{\partial \mathbf{y}_{i,b}})_{s'} $$

这样就可以抵消并行度带来的梯度缩放问题。

Return to Blog